# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import json

from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
from vllm.profiler.utils import TablePrinter, indent_string


def flatten_entries(entry_cls, profile_dict: dict):
    """Flatten a forest of profile nodes into ``(entry, depth)`` pairs.

    Nodes are visited depth-first in pre-order: a parent comes before its
    children, and siblings keep their original order. Each node is a mapping
    with an ``"entry"`` mapping (passed as keyword arguments to ``entry_cls``)
    and a ``"children"`` sequence of nodes of the same shape.

    Args:
        entry_cls: Callable used to build each entry from ``node["entry"]``;
            callers in this file pass ``SummaryStatsEntry`` or
            ``ModelStatsEntry``.
        profile_dict: Despite the ``dict`` annotation, this is iterated
            directly and every item is treated as a root node.

    Returns:
        A list of ``(entry_cls(**node["entry"]), depth)`` tuples, with root
        nodes at depth 0.
    """

    def _walk(nodes, depth):
        for node in nodes:
            # Emit the parent before descending so output stays in pre-order.
            yield entry_cls(**node["entry"]), depth
            yield from _walk(node["children"], depth + 1)

    return list(_walk(profile_dict, 0))


if __name__ == "__main__":
    # Command-line entry point: load a JSON trace written by
    # examples/offline_inference/profiling.py and print either the summary
    # table or the layerwise model table for a single profiling phase.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--json-trace",
        type=str,
        required=True,
        help="json trace file output by examples/offline_inference/profiling.py",
    )
    parser.add_argument(
        "--phase",
        type=str,
        required=True,
        help="The phase to print the table for. This is either "
        "prefill or decode_n, where n is the decode step "
        "number",
    )
    parser.add_argument(
        "--table",
        type=str,
        choices=["summary", "model"],
        default="summary",
        help="Which table to print, the summary table or the layerwise model table",
    )

    args = parser.parse_args()

    # JSON text is UTF-8 by spec; don't depend on the platform default encoding.
    with open(args.json_trace, encoding="utf-8") as f:
        profile_data = json.load(f)

    # Validate with parser.error instead of `assert`: asserts are stripped under
    # `python -O`, while parser.error prints usage and exits with status 2.
    if args.phase not in profile_data:
        available = [x for x in profile_data if "prefill" in x or "decode" in x]
        parser.error(
            f"Cannot find phase {args.phase} in profile data. "
            f"Choose one among {available}"
        )

    # Per-table configuration: the entry class to build, the key holding the
    # node forest inside the phase, and the printed width of each column.
    if args.table == "summary":
        entry_cls = SummaryStatsEntry
        stats_key = "summary_stats"
        column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15)
    else:  # "model" -- argparse `choices` rules out any other value
        entry_cls = ModelStatsEntry
        stats_key = "model_stats"
        column_widths = dict(
            name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60
        )

    entries_and_depths = flatten_entries(
        entry_cls, profile_data[args.phase][stats_key]
    )

    # Indent entry names based on their depth in the tree so the hierarchy
    # stays visible in the flat table; the style callable builds the prefix
    # ("|" + one dash per level + " ") that indent_string applies.
    entries = []
    for entry, depth in entries_and_depths:
        entry.name = indent_string(
            entry.name,
            indent=depth,
            indent_style=lambda indent: "|" + "-" * indent + " ",
        )
        entries.append(entry)

    # Use the known entry class rather than type(entries[0]) so a phase with
    # no entries doesn't crash with IndexError.
    TablePrinter(entry_cls, column_widths).print_table(entries)
